"""Tensor combination/concatenation demos (tensor的组合/拼接): ``torch.cat`` and ``torch.stack``."""
import torch


def test_cat():
    """Demonstrate ``torch.cat`` by joining two 2x4 tensors along dim 1.

    ``a`` is all zeros and ``b`` is all ones. Concatenating them at dim=1
    yields a 2x8 tensor whose left half is ``a`` and right half is ``b``.
    Everything is printed to stdout; nothing is returned.
    """
    print("==== cat =====")
    left = torch.zeros((2, 4))
    right = torch.ones(2, 4)

    # Show each input under its label, followed by a blank separator line.
    for label, tensor in (("a: ", left), ("b: ", right)):
        print(label)
        print(tensor)
    print()

    print("torch.cat((a,b),dim=1)  : ")
    print(torch.cat([left, right], dim=1))
    print()


def test_stack():
    """Demonstrate ``torch.stack`` on two 2x3 tensors along a new dim 1.

    ``a`` holds the values 1..6 and ``b`` holds 7..12, each shaped (2, 3).
    Stacking at dim=1 inserts a new axis of size 2, giving shape (2, 2, 3).
    Slicing that axis with ``out[:, 0, :]`` / ``out[:, 1, :]`` recovers
    ``a`` / ``b``. Everything is printed to stdout; nothing is returned.
    """
    print("==== stack =====")
    first = torch.linspace(1, 6, 6).reshape(2, 3)
    second = torch.linspace(7, 12, 6).reshape(2, 3)

    # Show each input under its label, followed by a blank separator line.
    for label, tensor in (("a: ", first), ("b: ", second)):
        print(label)
        print(tensor)
    print()

    stacked = torch.stack([first, second], dim=1)
    print("torch.stack((a,b),dim=1)  : ")
    print(stacked)
    print(stacked.shape)
    print()

    # Slice along the newly inserted axis to get each original input back.
    for label, idx in (("out[:, 0, :]", 0), ("out[:, 1, :]", 1)):
        print(label)
        print(stacked[:, idx, :])


if __name__ == '__main__':
    # Run each demo in order: concatenation first, then stacking.
    for demo in (test_cat, test_stack):
        demo()
